{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q8vHMOtbxH4y"
      },
      "source": [
        "Copyright 2022 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "     https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4lbN_l-JNuGa"
      },
      "source": [
        "# Super-acceleration with cyclical step-sizes\n",
        "\n",
        "This colab reproduces the figures from the blog post https://fa.bianp.net/2022/cyclical/ and the paper\n",
        "\n",
        "\u003e _Super-Acceleration with Cyclical Step-sizes_, Baptiste Goujaud, Damien Scieur, Aymeric Dieuleveut, Adrien Taylor, Fabian Pedregosa. Proceedings of the 25th International Conference on Artificial Intelligence and Statistics, 2022. https://arxiv.org/pdf/2106.09687.pdf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hpOpLKWOGphv"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "import matplotlib.font_manager as fm\n",
        "\n",
        "# for nicer fonts\n",
        "!wget https://github.com/openmaptiles/fonts/raw/master/open-sans/OpenSans-Light.ttf\n",
        "# Register the downloaded font with matplotlib. `createFontList` was\n",
        "# deprecated in matplotlib 3.2 and removed in later releases, so prefer\n",
        "# `FontManager.addfont` and only fall back to the old API when it is missing.\n",
        "if hasattr(fm.fontManager, 'addfont'):\n",
        "  fm.fontManager.addfont('OpenSans-Light.ttf')\n",
        "else:\n",
        "  fm.fontManager.ttflist += fm.createFontList(['OpenSans-Light.ttf'])\n",
        "\n",
        "# install apngasm for creating animated PNGs\n",
        "# (-y answers any confirmation prompt, the notebook has no interactive stdin)\n",
        "!apt-get install -y apngasm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_1VcgPu_kSjr"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.patches as patches\n",
        "\n",
        "from matplotlib import rcParams\n",
        "from matplotlib.ticker import StrMethodFormatter\n",
        "rcParams['font.size'] = 35\n",
        "rcParams['font.family'] = 'Open Sans'\n",
        "rcParams['font.weight'] = 'light'\n",
        "rcParams['mathtext.fontset'] = 'cm'\n",
        "\n",
        "import numpy as np\n",
        "from scipy import special\n",
        "\n",
        "# this is a color palette shared by some of the plots\n",
        "palette = [\n",
        "    '#66c2a5', '#fc8d62', '#8da0cb', '#e78ac3', '#a6d854', '#e41a1c', '#377eb8',\n",
        "    '#4daf4a', '#984ea3', '#ff7f00', '#ffff33', '#a65628', '#f781bf'\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4YUvFGukEt_r"
      },
      "source": [
        "# Cyclical Heavy Ball animation in 2D\n",
        "\n",
        "The following code generates the iterates of classical and cyclical heavy ball on a 2D problem, for easier visualization. It generates one PNG per iteration; these are then combined into a single animated PNG with apngasm.\n",
        "\n",
        "To download the generated file, find it in the \"Files\" tab, right-click it and select \"Download\"."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vVOsPq03kCJJ"
      },
      "outputs": [],
      "source": [
        "n_grid = 200\n",
        "max_iter = 10\n",
        "\n",
        "# Grid on which the loss is evaluated for the background of the animation\n",
        "x_grid = np.linspace(-2, 5, n_grid)\n",
        "X, Y = np.meshgrid(x_grid, x_grid)\n",
        "Z = np.stack((X, Y)).T\n",
        "x_init = np.array([1.5, 2.5])\n",
        "\n",
        "# A hessian with very different eigenvalues;\n",
        "# L and mu are its largest and smallest eigenvalue\n",
        "H = np.array([[2, 0], [0, 0.2]])\n",
        "hessian_eigs = np.linalg.eigvalsh(H)\n",
        "L, mu = hessian_eigs.max(), hessian_eigs.min()\n",
        "\n",
        "# Quadratic loss 0.5 * z^T H z at every grid point, displayed later with imshow\n",
        "loss_grid = 0.5 * ((Z @ H) * Z).sum(axis=-1)\n",
        "\n",
        "# --- Polyak momentum: one row of all_iterates_momentum per iteration ---\n",
        "sqrt_L, sqrt_mu = np.sqrt(L), np.sqrt(mu)\n",
        "h = (2 / (sqrt_L + sqrt_mu)) ** 2\n",
        "m = ((sqrt_L - sqrt_mu) / (sqrt_L + sqrt_mu)) ** 2\n",
        "all_iterates_momentum = np.zeros((max_iter, 2))\n",
        "x_prev = x_curr = x_init.copy()\n",
        "for k in range(max_iter):\n",
        "  all_iterates_momentum[k] = x_curr\n",
        "  grad = H @ x_curr\n",
        "  if k == 0:\n",
        "    # no previous iterate yet: plain gradient step with step-size 2 / (L + mu)\n",
        "    x_next = x_curr - (2 / (L + mu)) * grad\n",
        "  else:\n",
        "    x_next = x_curr - h * grad + m * (x_curr - x_prev)\n",
        "  x_prev, x_curr = x_curr, x_next\n",
        "\n",
        "\n",
        "# --- Cyclical heavy ball: one row of all_iterates_cyclical per iteration ---\n",
        "mu1, L2 = mu, L\n",
        "rho = (L2 + mu1) / (L2 - mu1)\n",
        "# here we choose a high R to have a clear super-acceleration effect\n",
        "R = 0.9\n",
        "half_band = (1 - R) * (L - mu) / 2\n",
        "L1 = mu + half_band\n",
        "mu2 = L - half_band\n",
        "m = ((np.sqrt(rho**2 - R**2) - np.sqrt(rho**2 - 1)) / np.sqrt(1 - R**2)) ** 2\n",
        "all_iterates_cyclical = np.zeros((max_iter, 2))\n",
        "x_prev = x_curr = x_init.copy()\n",
        "for k in range(max_iter):\n",
        "  all_iterates_cyclical[k] = x_curr\n",
        "  grad = H @ x_curr\n",
        "  if k == 0:\n",
        "    x_next = x_curr - (2 / (L + mu)) * grad\n",
        "  else:\n",
        "    # the step-size alternates: (1 + m) / L1 on even and (1 + m) / mu2 on odd iterations\n",
        "    ht = (1 + m) / (L1 if k % 2 == 0 else mu2)\n",
        "    x_next = x_curr - ht * grad + m * (x_curr - x_prev)\n",
        "  x_prev, x_curr = x_curr, x_next"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pmcv19lJH7g8"
      },
      "outputs": [],
      "source": [
        "# Frame i shows the first i iterates of both methods on top of the loss\n",
        "# landscape and is written to comparison_cyclical_<i>.png.\n",
        "for i in range(max_iter):\n",
        "  plt.figure(figsize=(20, 10))\n",
        "  # The former `lw=5` argument was dropped: `lw` is not a contour keyword\n",
        "  # (the documented one is `linewidths`), so the lines keep their default width.\n",
        "  plt.contour(X, Y, -loss_grid.T + 0.05, 50, colors='black')\n",
        "  plt.imshow(-loss_grid.T / np.max(np.abs(loss_grid)), extent=[-2, 5, -2, 5],\n",
        "           cmap='gist_heat', alpha=1)\n",
        "  plt.scatter([0], [0], color='black', s=80)\n",
        "  # raw string: '\\s' is not a valid escape sequence in a normal string literal\n",
        "  plt.text(0.05, 0, r'$x^\\star$', color='black')\n",
        "  plt.plot(all_iterates_momentum[:i,  0], all_iterates_momentum[:i,  1],\n",
        "           c='teal', lw=3, label='Heavy Ball', marker='d',\n",
        "           markersize=10)\n",
        "  plt.plot(all_iterates_cyclical[:i,  0], all_iterates_cyclical[:i,  1],\n",
        "           c='darkred', lw=3, label='Cyclical Heavy Ball',\n",
        "           marker='^', markersize=10)\n",
        "\n",
        "  plt.ylim((-0.2, 2.8))\n",
        "  plt.xlim((-2, 2))\n",
        "  plt.xticks(())\n",
        "  plt.yticks(())\n",
        "  plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),\n",
        "          frameon=False, ncol=1, fontsize=22)\n",
        "  # plt.gca() returns the axes that was just drawn on; plt.axes() would add\n",
        "  # a new, empty axes on top of it in current matplotlib releases.\n",
        "  plt.gca().set_aspect('equal')\n",
        "\n",
        "  f_path = 'comparison_cyclical_%02d.png' % i\n",
        "  plt.savefig(f_path, transparent=True, dpi=100, bbox_inches='tight')\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qeDVgf1WCDXV"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "# convert to animated PNG\n",
        "# (%%capture is a cell magic, so it has to be the very first line of the cell)\n",
        "!apngasm comparison_cyclical.png comparison_cyclical_01.png 1 1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-WyOMzn6jb1L"
      },
      "source": [
        "# Residual polynomial"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gv2VPQGUjeff"
      },
      "outputs": [],
      "source": [
        "# repeat the same plots but using the Polyak momentum polynomial\n",
        "\n",
        "colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#ffff33']\n",
        "\n",
        "mu, L = 0.2, 2\n",
        "\n",
        "def poly_gd(x, t):\n",
        "  \"\"\"Return (1 - step_size * x) ** t with step_size = 2 / (mu + L).\n",
        "\n",
        "  `mu` and `L` are read from the notebook's global namespace.\n",
        "  \"\"\"\n",
        "  step_size = 2 / (mu + L)\n",
        "  return (1 - step_size * x) ** (t)\n",
        "\n",
        "\n",
        "def poly_polyak(x, low, high, degree):\n",
        "  \"\"\"Evaluate the Polyak momentum polynomial of the given degree at x.\n",
        "\n",
        "  Args:\n",
        "    x: scalar or array of evaluation points.\n",
        "    low, high: positive end points from which the momentum m and the\n",
        "      step-size h are derived.\n",
        "    degree: degree of the Chebyshev polynomials T and U that are combined.\n",
        "\n",
        "  Returns:\n",
        "    m**(degree/2) * (2m/(1+m) * T(s) + (1-m)/(1+m) * U(s)), where\n",
        "    s = (1 + m - h * x) / (2 * sqrt(m)).\n",
        "  \"\"\"\n",
        "  m = ((np.sqrt(high) - np.sqrt(low))/(np.sqrt(high) + np.sqrt(low))) ** 2\n",
        "  h = (2 / (np.sqrt(high) + np.sqrt(low))) ** 2\n",
        "  s = (1 + m - h * x) / (2 * np.sqrt(m))\n",
        "  cheb1_part = special.eval_chebyt(degree, s)\n",
        "  cheb2_part = special.eval_chebyu(degree, s)\n",
        "  # scale with the `degree` argument; this used to read the global loop\n",
        "  # variable `t`, so the result depended on notebook state\n",
        "  return (m**(degree/2)) * ((2 * m / (1 + m)) * cheb1_part + ((1-m)/(1 + m)) * cheb2_part)\n",
        "\n",
        "\n",
        "def poly_cyclical(x, mu1, L1, mu2, L2, degree):\n",
        "  \"\"\"Evaluate the cyclical heavy ball polynomial of the given degree at x.\n",
        "\n",
        "  Args:\n",
        "    x: scalar or array of evaluation points.\n",
        "    mu1, L1, mu2, L2: the four constants from which rho, R, the momentum m\n",
        "      and the two step-sizes h0 = (1 + m) / L1, h1 = (1 + m) / mu2 are derived.\n",
        "    degree: degree of the Chebyshev polynomials T and U that are combined.\n",
        "\n",
        "  Returns:\n",
        "    m**(degree/2) * (2m/(1+m) * T(s) + (1-m)/(1+m) * U(s)), where s is the\n",
        "    signed square root of (1 + m - h0 * x) * (1 + m - h1 * x), divided by\n",
        "    2 * sqrt(m).\n",
        "  \"\"\"\n",
        "  rho = (L2 + mu1) / (L2 - mu1)\n",
        "  R = (mu2 - L1) / (L2 - mu1)\n",
        "  m = ((np.sqrt(rho ** 2 - R**2) - np.sqrt(rho**2 - 1)) / np.sqrt(1 - R**2)) ** 2\n",
        "  h0 = (1 + m) / L1\n",
        "  h1 = (1 + m) / mu2\n",
        "  tmp = (1 + m - h0 * x) * (1 + m - h1 * x)\n",
        "  # keep the sign of the product so that s stays real-valued\n",
        "  s = np.sqrt(np.abs(tmp)) * np.sign(tmp)/ (2 * np.sqrt(m))\n",
        "  cheb1_part = special.eval_chebyt(degree, s)\n",
        "  cheb2_part = special.eval_chebyu(degree, s)\n",
        "  # scale with the `degree` argument instead of the global loop variable `t`\n",
        "  return (m**(degree/2)) * ((2 * m / (1 + m)) * cheb1_part + ((1-m)/(1 + m)) * cheb2_part)\n",
        "\n",
        "\n",
        "x_grid = np.linspace(0, mu + L, 500)\n",
        "for t in range(2, 19, 2):\n",
        "\n",
        "  # With L1 == mu2 both step-sizes coincide (R = 0); this is the curve that\n",
        "  # is labelled as Polyak heavy ball below.\n",
        "  acc_y = poly_cyclical(x_grid, mu, (mu + L)/2, (mu + L)/2, L, t)\n",
        "\n",
        "  acc_cylical = poly_cyclical(x_grid, mu, mu + 0.12, L - 0.12, L, t)\n",
        "\n",
        "  f, axarr = plt.subplots(1, 1, figsize=(12, 10))\n",
        "  plt.title(\"Degree %s\" % t)\n",
        "  base_line_2, = axarr.plot(x_grid, acc_y, '--', lw=5, label='Polyak Heavy Ball $P^{Polyak}_t$', color='#ff7f0e')\n",
        "  base_line_3, = axarr.plot(x_grid, acc_cylical, lw=5, label='Cyclical Heavy Ball $P^{Cyclical}_t$', color=colors[3])\n",
        "\n",
        "  # raw strings: '\\l' and '\\m' are not valid escape sequences\n",
        "  axarr.set_ylabel(r'$P_{t}(\\lambda)$')\n",
        "  axarr.set_xlabel(r'$\\lambda$')\n",
        "\n",
        "  axarr.axvline(x=mu, color='grey',)\n",
        "  axarr.axvline(x=L, color='grey')\n",
        "  axarr.set_xticks((0.0, mu, 0.5, 1.0, 1.5, L))\n",
        "  axarr.set_xticklabels((0.0, r'$\\lambda_\\min$', None, None, None, r'$\\lambda_\\max$'))\n",
        "  axarr.set_yticks((-0.1, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0))\n",
        "  axarr.set_yticklabels((None, 0, None, None, None, None, 1.0))\n",
        "\n",
        "  axarr.legend(loc='upper center', frameon=False, bbox_to_anchor=(0.5, -0.1), ncol=1)\n",
        "  axarr.set_ylim((-0.1, 1))\n",
        "  axarr.grid()\n",
        "\n",
        "  f.subplots_adjust(wspace = 0.3) # pad a little\n",
        "\n",
        "  f_path = 'CyclicalResidualPolynomial%02d.png' % (t //2)\n",
        "  plt.savefig(f_path, transparent=True, dpi=100, bbox_inches='tight')\n",
        "\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R80xDfutQN7i"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "# convert to animated PNG\n",
        "# (%%capture is a cell magic, so it has to be the very first line of the cell)\n",
        "!apngasm CyclicalResidualPolynomial.png CyclicalResidualPolynomial01.png 1 1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cS75OyGhflJw"
      },
      "source": [
        "# Link function\n",
        "\n",
        "In this block we'll plot the link function of both classical and cyclical heavy ball. We'll generate different images for different input parameters, and as before, use apngasm to generate an animated PNG from them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JjlsiZMffnlY"
      },
      "outputs": [],
      "source": [
        "def sigma(x, m, h):\n",
        "  \"\"\"Return (1 + m - h * x) / (2 * sqrt(m)).\"\"\"\n",
        "  numerator = 1 + m - h * x\n",
        "  return numerator / (2 * np.sqrt(m))\n",
        "\n",
        "def zeta(x, m, h0, h1):\n",
        "  \"\"\"Return +/- sqrt((1 + m - h0 * x) * (1 + m - h1 * x) / (4 * m)) for an array x.\n",
        "\n",
        "  The positive root is used where both factors are positive and the negated\n",
        "  root everywhere else (NaN where the product is negative).\n",
        "  \"\"\"\n",
        "  factor0 = 1 + m - h0 * x\n",
        "  factor1 = 1 + m - h1 * x\n",
        "  both_positive = (factor0 \u003e 0) \u0026 (factor1 \u003e 0)\n",
        "  rest = ~both_positive\n",
        "  out = np.zeros_like(x)\n",
        "  out[both_positive] = np.sqrt(factor0[both_positive] * factor1[both_positive] / (4 * m))\n",
        "  out[rest] = -np.sqrt(factor0[rest] * factor1[rest] / (4 * m))\n",
        "  return out\n",
        "\n",
        "n_grid = 1000"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PC4CwyNOgCJT"
      },
      "outputs": [],
      "source": [
        "x_grid = np.linspace(0, 2, n_grid)\n",
        "\n",
        "# One frame per value of L, sweeping from 1 to 2 and back.\n",
        "# NOTE: the loop variable rebinds the global `L`; the cyclical link-function\n",
        "# cell further down reads `L` after this loop, so the name is kept as is.\n",
        "for i, L in enumerate(np.concatenate((np.linspace(1, 2, 20), np.linspace(2, 1, 20)))):\n",
        "  m = ((np.sqrt(L) - np.sqrt(mu)) / (np.sqrt(L) + np.sqrt(mu))) ** 2\n",
        "  h = (2 / (np.sqrt(L) + np.sqrt(mu))) ** 2\n",
        "\n",
        "  # raw strings: '\\s' is not a valid escape sequence in a normal string literal\n",
        "  plt.plot(x_grid, sigma(x_grid, m, h), lw=3, label=r'link function $\\sigma$')\n",
        "\n",
        "  # thick segment on the x axis marking the interval [mu, L]\n",
        "  yy = np.linspace(mu, L)\n",
        "  plt.plot(yy, np.zeros_like(yy), lw=10, alpha=0.5, label=r'$\\sigma^{-1}([-1, 1])$')\n",
        "  plt.title('Constant step-size link function', fontsize=28)\n",
        "  plt.ylim(-1.5, 1.5)\n",
        "  plt.yticks((-1, 0, 1), fontsize=22)\n",
        "  plt.xlim((0, L+mu))\n",
        "  plt.xticks(())\n",
        "  plt.grid()\n",
        "  plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),\n",
        "          frameon=False, ncol=2, fontsize=22)\n",
        "\n",
        "  f_path = 'link_function_constant_%02d.png' % i\n",
        "  plt.savefig(f_path, transparent=True, dpi=100, bbox_inches='tight')\n",
        "\n",
        "  plt.show()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dYJ-5VXl-efC"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "# convert to animated PNG\n",
        "# (%%capture is a cell magic, so it has to be the very first line of the cell)\n",
        "!apngasm link_function_constant.png link_function_constant_01.png 1 10"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HCDiis2rg7-A"
      },
      "outputs": [],
      "source": [
        "x_grid = np.linspace(0, 2, n_grid)\n",
        "\n",
        "# NOTE: `L` and `mu` are whatever the previous cells left in the namespace\n",
        "rho = (L + mu) / (L - mu)\n",
        "# One frame per value of R, sweeping from 0 to 0.5 and back.\n",
        "for i, R in enumerate(np.concatenate((np.linspace(0, 0.5, 20), np.linspace(0.5, 0, 20)))):\n",
        "  m = (np.sqrt(rho ** 2 - R ** 2) - np.sqrt(rho ** 2 - 1)) ** 2 / (1 - R ** 2)\n",
        "  L1 = mu + (1 - R) * (L - mu) / 2\n",
        "  mu2 = L - (1 - R) * (L - mu) / 2\n",
        "  h0 = (1 + m) / L1\n",
        "  h1 = (1 + m) / mu2\n",
        "\n",
        "  # raw strings: '\\z' is not a valid escape sequence in a normal string literal\n",
        "  plt.plot(x_grid, zeta(x_grid, m, h0, h1), lw=3, label=r'link function $\\zeta$')\n",
        "\n",
        "  # thick segment on the x axis covering [mu, L]; the points strictly between\n",
        "  # L1 and mu2 are set to NaN so that matplotlib leaves a gap there\n",
        "  yy = np.linspace(mu, L)\n",
        "  idx = (yy \u003e L1) \u0026 (yy \u003c mu2)\n",
        "  yy_img = np.zeros_like(yy)\n",
        "  # np.nan instead of the np.NaN alias, which was removed in NumPy 2.0\n",
        "  yy_img[idx] = np.nan\n",
        "  plt.plot(yy, yy_img, lw=10, alpha=0.5, label=r'$\\zeta^{-1}([-1, 1])$')\n",
        "  plt.title('Cyclical step-size link function', fontsize=28)\n",
        "  plt.ylim(-1.5, 1.5)\n",
        "  plt.yticks((-1, 0, 1), fontsize=22)\n",
        "  plt.xlim((0, L+mu))\n",
        "  plt.xticks(())\n",
        "  plt.grid()\n",
        "  plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),\n",
        "          frameon=False, ncol=2, fontsize=22)\n",
        "\n",
        "  f_path = 'link_function_cyclical_%02d.png' % i\n",
        "  plt.savefig(f_path, transparent=True, dpi=100, bbox_inches='tight')\n",
        "\n",
        "  plt.show()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pyi-G6Xy-bA0"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "# convert to animated PNG\n",
        "# (%%capture is a cell magic, so it has to be the very first line of the cell)\n",
        "!apngasm link_function_cyclical.png link_function_cyclical_01.png 1 10"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_J5nHiYvjuWn"
      },
      "source": [
        "# Spectral density\n",
        "\n",
        "In this section we download the MNIST dataset and plot the Hessian eigenvalues for a quadratic objective. We'll overlay the quantities $\\mu_1, \\mu_2, L_1, L_2$ that are important for optimization."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kyhU3Qm9enI5"
      },
      "outputs": [],
      "source": [
        "# load the MNIST dataset and convert to a numpy array\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "train_split = tfds.load(name='mnist', split='train')\n",
        "\n",
        "# tfds.as_numpy turns the `tf.data.Dataset` into a Python generator of examples;\n",
        "# every image is flattened into a vector before it is stored\n",
        "flat_images, labels = [], []\n",
        "for example in tfds.as_numpy(train_split):\n",
        "  flat_images.append(example['image'].ravel())\n",
        "  labels.append(example['label'])\n",
        "\n",
        "# pixel values are divided by 255; both arrays are stored as float64\n",
        "mnist_images = np.array(flat_images).astype(np.float64) / 255.\n",
        "mnist_target = np.array(labels).astype(np.float64)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "neevg5r6fTTo"
      },
      "outputs": [],
      "source": [
        "# Eigenvalues of the Gram matrix of the flattened images\n",
        "# (eigvalsh returns them in ascending order)\n",
        "H = mnist_images.T @ mnist_images\n",
        "eigs = np.linalg.eigvalsh(H)\n",
        "\n",
        "# mu1 / L2: smallest / largest eigenvalue, L1: second largest eigenvalue,\n",
        "# mu2: chosen such that L2 - mu2 == L1 - mu1\n",
        "mu1, L1, L2 = np.min(eigs), eigs[-2], eigs[-1]\n",
        "mu2 = L2 - (L1 - mu1)\n",
        "\n",
        "# the printed value is the ratio smallest / largest eigenvalue\n",
        "print('condition number', mu1 / np.max(eigs))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iJ83SMzeekhj"
      },
      "outputs": [],
      "source": [
        "# Histogram of the eigenvalues (normalized by L2) with mu1, L1, mu2 and L2\n",
        "# marked by dashed vertical lines.\n",
        "fig, axarr = plt.subplots(1, 1, figsize=(1 * 10, 1 * 8))\n",
        "\n",
        "axarr.hist(eigs / L2, 50)\n",
        "axarr.set_yscale(\"log\")\n",
        "axarr.axvline(L1 / L2, color='#4DAF4A', linestyle='--', lw=1)\n",
        "axarr.axvline(mu1 / L2, color=palette[8], linestyle='--', lw=1)\n",
        "axarr.axvline(L2 / L2, color=palette[8], linestyle='--', lw=1)\n",
        "axarr.axvline(mu2 / L2, color='#4DAF4A', linestyle='--', lw=1)\n",
        "\n",
        "# raw strings: '\\m' is not a valid escape sequence in a normal string literal\n",
        "axarr.text(L1 * 0.999  / L2, 500, '$L_1$', color='#4DAF4A')\n",
        "axarr.text(mu2 * 0.999  / L2, 500, r'$\\mu_2$', color='#4DAF4A')\n",
        "axarr.text(L2 * 0.999 / L2, 500, '$L_2$', color=palette[8])\n",
        "axarr.text(mu1 * 0.999 / L2, 500, r'$\\mu_1$', color=palette[8])\n",
        "plt.xticks(())\n",
        "\n",
        "axarr.set_ylabel(\"density\")\n",
        "axarr.set_xlabel(\"eigenvalue magnitude\")\n",
        "\n",
        "# double-headed arrows for the two lengths L2 - mu1 and mu2 - L1\n",
        "p1 = patches.FancyArrowPatch((0, 200), (1, 200), arrowstyle='\u003c-\u003e',\n",
        "                             mutation_scale=20, color=palette[8], linewidth=3)\n",
        "axarr.add_patch(p1)\n",
        "p2 = patches.FancyArrowPatch((L1 / L2, 60), (mu2 / L2, 60), arrowstyle='\u003c-\u003e',\n",
        "                             mutation_scale=20, color='#4DAF4A', linewidth=3)\n",
        "axarr.add_patch(p2)\n",
        "axarr.text(0.4, 270, r\"$L_2 - \\mu_1$\", color=palette[8], fontsize=30)\n",
        "axarr.text(0.4, 80, r\"$\\mu_2 - L_1$\", color='#4DAF4A', fontsize=30)\n",
        "# the fraction R = (mu2 - L1) / (L2 - mu1) is assembled from three text boxes\n",
        "# so that numerator and denominator can have different colors\n",
        "axarr.text(0.33, 12, r\"$R = \\frac{~~~~~~~~~~~~~~~~}{~~~}$\", color='k', fontsize=30)\n",
        "axarr.text(0.44, 16, r\"$\\mu_2 - L_1$\", color='#4DAF4A', fontsize=30)\n",
        "axarr.text(0.44, 9, r\"$L_2 - \\mu_1$\", color=palette[8], fontsize=30)\n",
        "\n",
        "plt.tight_layout()\n",
        "\n",
        "f_path = 'spectrum_mnist.png'\n",
        "# `fc`, `ec` and `shape` are not savefig parameters (the documented names are\n",
        "# `facecolor` / `edgecolor`); current matplotlib rejects unknown keywords,\n",
        "# so they were removed.\n",
        "fig.savefig(f_path, dpi=300, bbox_inches='tight', transparent=True)\n",
        "\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EDhuneyRjuhG"
      },
      "source": [
        "# Robust region of cyclical heavy ball\n",
        "\n",
        "Here we plot the growing robust region as a function of the relative gap $R$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JgkAkwZS52Qt"
      },
      "outputs": [],
      "source": [
        "# fix the problem constants. Can be changed\n",
        "# and will yield slightly different figures\n",
        "mu, L = 0.1, 2\n",
        "n_grid = 1500"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EDuOYimAf6fw"
      },
      "outputs": [],
      "source": [
        "all_m = np.linspace(1, 1e-12, n_grid)\n",
        "all_h = np.linspace(0, 1, n_grid)\n",
        "# NOTE: `h_grid` is passed to sigma_r as its `r` argument below\n",
        "m_grid, h_grid = np.meshgrid(all_m, all_h)\n",
        "m_polyak = ((np.sqrt(L) - np.sqrt(mu)) / (np.sqrt(L) + np.sqrt(mu))) ** 2\n",
        "\n",
        "robust_m = all_m[all_m \u003e m_polyak]\n",
        "h_polyak = (2 / (np.sqrt(L) + np.sqrt(mu))) ** 2\n",
        "rho = (L + mu) / (L-mu)\n",
        "\n",
        "\n",
        "def varphi(xi):\n",
        "  \"\"\"Return |xi| + sqrt(xi**2 - 1).\"\"\"\n",
        "  return np.abs(xi) + np.sqrt(xi**2 - 1)\n",
        "\n",
        "\n",
        "def sigma_r(x, m, r):\n",
        "  \"\"\"Return |(1 + m - h0 * x) * (1 + m - h1 * x) / (2 * m) - 1|.\n",
        "\n",
        "  The two step-sizes are h = (1 + m) / ((L + mu) / 2 -/+ r * (L - mu) / 2),\n",
        "  with `mu` and `L` read from the notebook's global namespace.\n",
        "  \"\"\"\n",
        "  h1 = (1+m) / (0.5 * (L + mu) - r * 0.5 * (L-mu))\n",
        "  h0 = (1+m) / (0.5 * (L + mu) + r * 0.5 * (L-mu))\n",
        "  return abs((1 + m - h0 * x) * (1 + m - h1 * x) / (2 * m) - 1)\n",
        "\n",
        "\n",
        "all_R = np.linspace(0, 1.0, 30)\n",
        "for i_R, R in enumerate(np.concatenate([all_R, all_R[::-1]])):\n",
        "\n",
        "  L1 = (mu + (1-R) * (L-mu)/2 )\n",
        "  mu2 = (L - (1-R) * (L-mu)/2)\n",
        "  # The previously computed `optimal_m` was never used in this cell and\n",
        "  # evaluated 0 / 0 (RuntimeWarning) for R = 1, so it was removed.\n",
        "\n",
        "  # largest value of sigma_r over the four points mu, L1, mu2 and L\n",
        "  # (sigma_r already returns absolute values)\n",
        "  smax = np.max([sigma_r(x, m_grid, h_grid) for x in (mu, L1, mu2, L)], axis=0)\n",
        "  # sqrt(m) where smax is at most 1, NaN (left blank by pcolor) elsewhere\n",
        "  rate = np.where(smax \u003c= 1, np.sqrt(m_grid), np.nan)\n",
        "\n",
        "  plt.figure(figsize=(16, 8))\n",
        "  plt.title(f'R={R:.{2}}')\n",
        "  plt.pcolor(m_grid, h_grid, rate, vmin=0.55)\n",
        "  plt.xlabel(r'momentum $m$')\n",
        "  plt.ylabel(r'parameter $r$')\n",
        "  cbar = plt.colorbar()\n",
        "  cbar.ax.set_ylabel('asymptotic rate')\n",
        "\n",
        "  plt.grid()\n",
        "  plt.ylim((0, 1))\n",
        "  f_path = 'robust_region_cyclical_%02d.png' % i_R\n",
        "  plt.savefig(f_path, transparent=True, dpi=100, bbox_inches='tight')\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8y29fpn6h9m6"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "!apngasm robust_region_cyclical.png robust_region_cyclical_00.png 1 5"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t1JTB3E-CHNk"
      },
      "source": [
        "# Landscape\n",
        "\n",
        "Plot the convergence rate in color as a function of the two step-sizes. The rate that we display is a consequence of Theorem 3 in [the paper](https://arxiv.org/pdf/2106.09687.pdf).\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fGJP-n0ool-Q"
      },
      "outputs": [],
      "source": [
        "H = mnist_images.T @ mnist_images\n",
        "\n",
        "raw_eigs = np.linalg.eigvalsh(H)\n",
        "# regularization: shift every eigenvalue by 1% of the largest one ...\n",
        "shifted_eigs = raw_eigs + 1e-2 * np.max(raw_eigs)\n",
        "# ... and normalize so that the largest eigenvalue is 1\n",
        "eigs = shifted_eigs / np.max(shifted_eigs)\n",
        "\n",
        "mu1, L2 = np.min(eigs), np.max(eigs)\n",
        "\n",
        "# the printed value is the ratio smallest / largest eigenvalue\n",
        "print('condition number', mu1 / L2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "FG_LPiFOCNi5"
      },
      "outputs": [],
      "source": [
        "n_grid = 2000\n",
        "h_max = 12\n",
        "\n",
        "smallest_R = 0.75\n",
        "# rho only depends on mu1 and L2, so it is computed once, before its first\n",
        "# use. `minimum_rate` used to be evaluated before this cell had assigned\n",
        "# `rho`, i.e. with the value left behind by an earlier cell (or a NameError).\n",
        "rho = (L2 + mu1) / (L2 - mu1)\n",
        "# sqrt(optimal_m) for R = smallest_R; used as lower end of the color scale\n",
        "minimum_rate = ((np.sqrt(rho**2 - smallest_R**2) - np.sqrt(rho**2 - 1)) / np.sqrt(1 - smallest_R**2))\n",
        "\n",
        "# NOTE: this rebinds the three-argument `sigma(x, m, h)` defined in the\n",
        "# link-function section; the name is kept to leave the namespace unchanged.\n",
        "def sigma(x, m, h0, h1):\n",
        "  \"\"\"Return 2 * s0 * s1 - 1 with s_i = (1 + m - h_i * x) / (2 * sqrt(m)).\"\"\"\n",
        "  return 2 *( (1 + m - h0 * x) / (2 * np.sqrt(m))) * ((1 + m - h1 * x) / (2 * np.sqrt(m))) - 1\n",
        "\n",
        "\n",
        "# grid of candidate step-sizes; it does not depend on R, so build it only once\n",
        "all_h0 = np.linspace(1e-6, h_max, n_grid)\n",
        "all_h1 = np.linspace(1e-6, h_max, n_grid)\n",
        "h0_grid, h1_grid = np.meshgrid(all_h0, all_h1)\n",
        "\n",
        "all_R = np.linspace(0, smallest_R, 40)\n",
        "for it, R in enumerate(np.concatenate([all_R, all_R[::-1]])):\n",
        "  fig, axarr = plt.subplots(1, 2, figsize=(2 * 10, 1 * 8))\n",
        "\n",
        "  L1 = (mu1 + L2)/2 - (L2 - mu1) * R / 2\n",
        "  mu2 = (mu1 + L2)/2 + (L2 - mu1) * R / 2\n",
        "\n",
        "  # left panel: eigenvalue histogram with mu1, L1, mu2 and L2 marked\n",
        "  axarr[0].hist(eigs / L2, 50)\n",
        "  axarr[0].set_yscale(\"log\")\n",
        "  axarr[0].axvline(L1 / L2, color='#4DAF4A', linestyle='--', lw=1)\n",
        "  axarr[0].axvline(mu1 / L2, color=palette[8], linestyle='--', lw=1)\n",
        "  axarr[0].axvline(L2 / L2, color=palette[8], linestyle='--', lw=1)\n",
        "  axarr[0].axvline(mu2 / L2, color='#4DAF4A', linestyle='--', lw=1)\n",
        "\n",
        "\n",
        "  # raw strings: '\\m' is not a valid escape sequence in a normal string literal\n",
        "  axarr[0].text(L1 * 0.999  / L2, 500, '$L_1$', color='#4DAF4A')\n",
        "  axarr[0].text(mu2 * 0.999  / L2, 500, r'$\\mu_2$', color='#4DAF4A')\n",
        "  axarr[0].text(L2 * 0.999 / L2, 500, '$L_2$', color=palette[8])\n",
        "  axarr[0].text(mu1 * 0.999 / L2, 500, r'$\\mu_1$', color=palette[8])\n",
        "  axarr[0].set_xticks(())\n",
        "\n",
        "  axarr[0].set_ylabel(\"density\")\n",
        "  axarr[0].set_xlabel(\"eigenvalue magnitude\")\n",
        "\n",
        "\n",
        "  p1 = patches.FancyArrowPatch((0, 200), (1, 200), arrowstyle='\u003c-\u003e',\n",
        "                              mutation_scale=20, color=palette[8], linewidth=3)\n",
        "  axarr[0].add_patch(p1)\n",
        "  p2 = patches.FancyArrowPatch((L1 / L2, 60), (mu2 / L2, 60), arrowstyle='\u003c-\u003e',\n",
        "                              mutation_scale=20, color='#4DAF4A', linewidth=3)\n",
        "  axarr[0].add_patch(p2)\n",
        "  axarr[0].text(0.4, 270, r\"$L_2 - \\mu_1$\", color=palette[8], fontsize=30)\n",
        "  axarr[0].text(0.4, 80, r\"$\\mu_2 - L_1$\", color='#4DAF4A', fontsize=30)\n",
        "  axarr[0].text(0.33, 12, r\"$R = \\frac{~~~~~~~~~~~~~~~~}{~~~}$ = %.2f\" % R, color='k', fontsize=30)\n",
        "  axarr[0].text(0.44, 16, r\"$\\mu_2 - L_1$\", color='#4DAF4A', fontsize=30)\n",
        "  axarr[0].text(0.44, 9, r\"$L_2 - \\mu_1$\", color=palette[8], fontsize=30)\n",
        "\n",
        "\n",
        "  optimal_m = (\n",
        "      (np.sqrt(rho ** 2 - R ** 2) - np.sqrt(rho ** 2 - 1)) / \\\n",
        "      np.sqrt(1 - R**2)) ** 2\n",
        "  optimal_h0 = (1 + optimal_m) / mu2\n",
        "  optimal_h1 = (1 + optimal_m) / L1\n",
        "\n",
        "  # entries that are not assigned below stay NaN and are left blank by pcolor\n",
        "  # (np.nan instead of the np.NaN alias, which was removed in NumPy 2.0)\n",
        "  rate = np.full((n_grid, n_grid), np.nan)\n",
        "\n",
        "  for i in range(n_grid):\n",
        "\n",
        "    # order the pair so that h0 is the smaller and h1 the larger step-size\n",
        "    h0 = np.minimum(h0_grid[i], h1_grid[i])\n",
        "    h1 = np.maximum(h0_grid[i], h1_grid[i])\n",
        "\n",
        "    # compute \\sigma_star from Theorem 3.1 in https://arxiv.org/pdf/2106.09687.pdf\n",
        "    tmp0 = np.abs(sigma(mu1, optimal_m, h0, h1))\n",
        "    tmp1 = np.abs(sigma(L1, optimal_m, h0, h1))\n",
        "    tmp2 = np.abs(sigma(mu2, optimal_m, h0, h1))\n",
        "    tmp3 = np.abs(sigma(L2, optimal_m, h0, h1))\n",
        "    # stationary point of sigma as a function of x; it only contributes when\n",
        "    # it lies inside [mu1, L1] or [mu2, L2]\n",
        "    x_mid = (1 + optimal_m) * (h0 + h1) / (2 * h0 * h1)\n",
        "    idx = ((mu1 \u003c= x_mid) \u0026 (x_mid \u003c= L1)) | ((mu2 \u003c= x_mid) \u0026 (x_mid \u003c= L2))\n",
        "    tmp4 = np.where(idx, np.abs(sigma(x_mid, optimal_m, h0, h1)), 0.0)\n",
        "    sigma_star = np.max((tmp0, tmp1, tmp2, tmp3, tmp4), axis=0)\n",
        "\n",
        "    idx_robust = sigma_star \u003c= 1\n",
        "    rate[i, idx_robust] = np.sqrt(optimal_m)\n",
        "    # Only sigma_star \u003e 1 belongs here: without that bound the robust entries\n",
        "    # were overwritten with NaN, because sqrt(sigma_star**2 - 1) is undefined\n",
        "    # below 1.\n",
        "    idx_convergent = (sigma_star \u003e 1) \u0026 (sigma_star \u003c= (1 + optimal_m ** 2) / (2 * optimal_m))\n",
        "    s_conv = sigma_star[idx_convergent]\n",
        "    rate[i, idx_convergent] = np.sqrt(optimal_m * (s_conv + np.sqrt(s_conv**2 - 1)))\n",
        "\n",
        "\n",
        "  # right panel: rate as a function of the two step-sizes\n",
        "  pc = axarr[1].pcolor(h0_grid, h1_grid, rate, rasterized=True, cmap='viridis', vmin=minimum_rate, vmax=1)\n",
        "  axarr[1].set_xticks(())\n",
        "  axarr[1].set_yticks(())\n",
        "  axarr[1].set_xlabel(r\"First step-size $h_0$\")\n",
        "  axarr[1].set_ylabel(r\"Second step-size $h_1$\")\n",
        "  axarr[1].spines['top'].set_visible(False)\n",
        "  axarr[1].spines['right'].set_visible(False)\n",
        "  axarr[1].scatter(optimal_h0, optimal_h1, s=400, facecolors='none', edgecolors='#d95f02', lw=3)\n",
        "  axarr[1].plot(np.linspace(0, optimal_h0, 100), optimal_h1 * np.ones(100), '--', c='#d95f02', lw=2)\n",
        "  axarr[1].plot(np.linspace(optimal_h0, optimal_h0, 100), np.linspace(0, optimal_h1, 100), '--', c='#d95f02', lw=2)\n",
        "  axarr[1].text(2, 10, r\"$\\circ$ optimal parameters\", c='#d95f02')\n",
        "\n",
        "  if it \u003e 0:\n",
        "    # also mark the same pair with the roles of h0 and h1 swapped\n",
        "    axarr[1].scatter(optimal_h1, optimal_h0, s=400, facecolors='none', edgecolors='#d95f02', lw=2)\n",
        "    axarr[1].plot(np.linspace(0, optimal_h1, 100), optimal_h0 * np.ones(100), '--', c='#d95f02', lw=2)\n",
        "    axarr[1].plot(np.linspace(optimal_h1, optimal_h1, 100), np.linspace(0, optimal_h0, 100), '--', c='#d95f02', lw=2)\n",
        "\n",
        "  axarr[1].set_xlim((0, None))\n",
        "  axarr[1].set_ylim((0, None))\n",
        "\n",
        "  fig.subplots_adjust(right=0.80)\n",
        "  cbar_ax = fig.add_axes([0.82, 0.15, 0.02, 0.7])\n",
        "  fig.colorbar(pc,  cax=cbar_ax, ticks=[0.8, 1])\n",
        "  cbar_ax.set_ylabel(r'asymptotic rate')\n",
        "\n",
        "  f_path = 'rate_convergence_cyclical_%02d.png' % it\n",
        "  plt.savefig(f_path, transparent=True, dpi=100, bbox_inches='tight')\n",
        "  plt.show()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "fDOeiphBO9Xt"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "!apngasm rate_convergence_cyclical.png rate_convergence_cyclical_00.png 1 3"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wQNJg-IeRG6P"
      },
      "source": [
        "# Convergence rate comparison\n",
        "\n",
        "In this section we compare the asymptotic convergence rates for different condition numbers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "BXt6czc_Zp7G"
      },
      "outputs": [],
      "source": [
        "all_R = np.linspace(0, 1, endpoint=False)\n",
        "all_kappa = np.logspace(0, -4, 20)[1:]\n",
        "\n",
        "\n",
        "def cyclical_rate(kappa, R):\n",
        "    \"\"\"Return the rates (r_polyak, r_cyclical, r_approx) for kappa and R.\n",
        "\n",
        "    Args:\n",
        "      kappa: scalar or array different from 1, from which\n",
        "        rho = (1 + kappa) / (1 - kappa) is built.\n",
        "      R: scalar or array with |R| below 1.\n",
        "\n",
        "    Returns:\n",
        "      Tuple of three arrays that share the broadcast shape of `kappa` and `R`:\n",
        "      r_polyak = rho - sqrt(rho^2 - 1),\n",
        "      r_cyclical = (sqrt(rho^2 - R^2) - sqrt(rho^2 - 1)) / sqrt(1 - R^2) and\n",
        "      r_approx = 1 - (1 - r_polyak) / sqrt(1 - R^2).\n",
        "    \"\"\"\n",
        "    # Broadcasting replaces the manual `[rho] * len(R)` replication, which\n",
        "    # failed for Python floats (no `.shape`) and for scalar R (no `len`).\n",
        "    rho, R = np.broadcast_arrays((1 + kappa) / (1 - kappa), R)\n",
        "    r_polyak = rho - np.sqrt(rho ** 2 - 1)\n",
        "    r_cyclical = (np.sqrt(rho ** 2 - R ** 2) - np.sqrt(rho ** 2 - 1)) / np.sqrt(1 - R ** 2)\n",
        "    r_approx = 1 - (1 - r_polyak) / np.sqrt(1 - R ** 2)\n",
        "\n",
        "    return r_polyak, r_cyclical, r_approx\n",
        "\n",
        "\n",
        "# One frame per value of kappa, sweeping through all_kappa and back.\n",
        "for i, kappa in enumerate(np.concatenate([all_kappa, all_kappa[::-1]])):\n",
        "    plt.figure(figsize=(8, 6))\n",
        "    plt.title(f'$\\\\kappa$ = {kappa:.{2}}')\n",
        "\n",
        "    r_polyak, r_cyclical, r_approx = cyclical_rate(kappa, all_R)\n",
        "\n",
        "    plt.plot(all_R, (1-r_polyak), lw=4, label='Polyak', marker='d', markevery=20, markersize=10)\n",
        "    plt.plot(all_R, (1-r_cyclical), lw=4, label='Cyclical', marker='^', markevery=18, markersize=10)\n",
        "    plt.plot(all_R, (1-r_approx), '--', lw=4, label='Approx', marker='s', markevery=15, markersize=10)\n",
        "\n",
        "    plt.grid()\n",
        "    plt.xlabel('R')\n",
        "    plt.ylabel('Rate factor')\n",
        "    plt.gca().yaxis.set_major_formatter(StrMethodFormatter('{x:,.1f}')) # 1 decimal place\n",
        "\n",
        "\n",
        "    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.25), frameon=False, ncol=3, fontsize=26)\n",
        "\n",
        "    f_path = 'asymptotic_rate_%02d.png' % i\n",
        "    plt.savefig(f_path, transparent=True, dpi=50, bbox_inches='tight')\n",
        "\n",
        "    plt.show()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "H-aHMyoSW-I_"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "!apngasm asymptotic_rate.png asymptotic_rate_00.png 1 5"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "cyclical_learning_rates.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
